{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# PyTorch DDP Fashion MNIST Training Example\n",
    "\n",
    "This example demonstrates how to train a convolutional neural network (CNN) to classify images using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset and [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).\n",
    "\n",
    "This notebook walks you through running that example locally, and how to easily scale PyTorch DDP across multiple nodes with Kubeflow TrainJob."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Install the Kubeflow SDK\n",
    "\n",
    "You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# !pip install -U kubeflow"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install the PyTorch Dependencies\n",
    "\n",
    "You also need to install PyTorch and Torchvision to be able to run the example locally:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torch==2.7.1\n",
    "!pip install torchvision==0.22.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Training Function\n",
    "\n",
    "The first step is to create function to train CNN model using Fashion MNIST data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:46.917723Z",
     "iopub.status.busy": "2025-09-03T13:19:46.917308Z",
     "iopub.status.idle": "2025-09-03T13:19:46.935181Z",
     "shell.execute_reply": "2025-09-03T13:19:46.934697Z",
     "shell.execute_reply.started": "2025-09-03T13:19:46.917698Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_fashion_mnist():\n",
    "    import os\n",
    "\n",
    "    import torch\n",
    "    import torch.distributed as dist\n",
    "    import torch.nn.functional as F\n",
    "    from torch import nn\n",
    "    from torch.utils.data import DataLoader, DistributedSampler\n",
    "    from torchvision import datasets, transforms\n",
    "\n",
    "    # Define the PyTorch CNN model to be trained\n",
    "    class Net(nn.Module):\n",
    "        def __init__(self):\n",
    "            super(Net, self).__init__()\n",
    "            self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "            self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "            self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "            self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "        def forward(self, x):\n",
    "            x = F.relu(self.conv1(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = F.relu(self.conv2(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = x.view(-1, 4 * 4 * 50)\n",
    "            x = F.relu(self.fc1(x))\n",
    "            x = self.fc2(x)\n",
    "            return F.log_softmax(x, dim=1)\n",
    "\n",
    "    # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n",
    "    device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
    "    print(f\"Using Device: {device}, Backend: {backend}\")\n",
    "\n",
    "    # Setup PyTorch distributed.\n",
    "    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
    "    dist.init_process_group(backend=backend)\n",
    "    print(\n",
    "        \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n",
    "            dist.get_world_size(),\n",
    "            dist.get_rank(),\n",
    "            local_rank,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Create the model and load it into the device.\n",
    "    device = torch.device(f\"{device}:{local_rank}\")\n",
    "    model = nn.parallel.DistributedDataParallel(Net().to(device))\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
    "\n",
    "    \n",
    "    # Download FashionMNIST dataset only on local_rank=0 process.\n",
    "    if local_rank == 0:\n",
    "        dataset = datasets.FashionMNIST(\n",
    "            \"./data\",\n",
    "            train=True,\n",
    "            download=True,\n",
    "            transform=transforms.Compose([transforms.ToTensor()]),\n",
    "        )\n",
    "    dist.barrier()\n",
    "    dataset = datasets.FashionMNIST(\n",
    "        \"./data\",\n",
    "        train=True,\n",
    "        download=False,\n",
    "        transform=transforms.Compose([transforms.ToTensor()]),\n",
    "    )\n",
    "\n",
    "\n",
    "    # Shard the dataset accross workers.\n",
    "    train_loader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=100,\n",
    "        sampler=DistributedSampler(dataset)\n",
    "    )\n",
    "\n",
    "    # TODO(astefanutti): add parameters to the training function\n",
    "    dist.barrier()\n",
    "    for epoch in range(1, 3):\n",
    "        model.train()\n",
    "\n",
    "        # Iterate over mini-batches from the training set\n",
    "        for batch_idx, (inputs, labels) in enumerate(train_loader):\n",
    "            # Copy the data to the GPU device if available\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "            # Forward pass\n",
    "            outputs = model(inputs)\n",
    "            loss = F.nll_loss(outputs, labels)\n",
    "            # Backward pass\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if batch_idx % 10 == 0 and dist.get_rank() == 0:\n",
    "                print(\n",
    "                    \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
    "                        epoch,\n",
    "                        batch_idx * len(inputs),\n",
    "                        len(train_loader.dataset),\n",
    "                        100.0 * batch_idx / len(train_loader),\n",
    "                        loss.item(),\n",
    "                    )\n",
    "                )\n",
    "\n",
    "    # Wait for the distributed training to complete\n",
    "    dist.barrier()\n",
    "    if dist.get_rank() == 0:\n",
    "        print(\"Training is finished\")\n",
    "\n",
    "    # Finally clean up PyTorch distributed\n",
    "    dist.destroy_process_group()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dry-run the Training Locally\n",
    "\n",
    "We are going to download Fashion MNIST Dataset and start local training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using Device: cpu, Backend: gloo\n",
      "Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0\n",
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.300957\n",
      "Train Epoch: 1 [1000/60000 (2%)]\tLoss: 2.068733\n",
      "Train Epoch: 1 [2000/60000 (3%)]\tLoss: 1.136209\n",
      "Train Epoch: 1 [3000/60000 (5%)]\tLoss: 0.897645\n",
      "Train Epoch: 1 [4000/60000 (7%)]\tLoss: 0.868213\n",
      "Train Epoch: 1 [5000/60000 (8%)]\tLoss: 0.943094\n",
      "Train Epoch: 1 [6000/60000 (10%)]\tLoss: 0.785388\n",
      "Train Epoch: 1 [7000/60000 (12%)]\tLoss: 0.772929\n",
      "Train Epoch: 1 [8000/60000 (13%)]\tLoss: 0.612156\n",
      "Train Epoch: 1 [9000/60000 (15%)]\tLoss: 0.783617\n",
      "Train Epoch: 1 [10000/60000 (17%)]\tLoss: 0.608218\n",
      "Train Epoch: 1 [11000/60000 (18%)]\tLoss: 0.595062\n",
      "Train Epoch: 1 [12000/60000 (20%)]\tLoss: 0.612815\n",
      "Train Epoch: 1 [13000/60000 (22%)]\tLoss: 0.523903\n",
      "Train Epoch: 1 [14000/60000 (23%)]\tLoss: 0.673269\n",
      "Train Epoch: 1 [15000/60000 (25%)]\tLoss: 0.675361\n",
      "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.591610\n",
      "Train Epoch: 1 [17000/60000 (28%)]\tLoss: 0.515040\n",
      "Train Epoch: 1 [18000/60000 (30%)]\tLoss: 0.535079\n",
      "Train Epoch: 1 [19000/60000 (32%)]\tLoss: 0.721351\n",
      "Train Epoch: 1 [20000/60000 (33%)]\tLoss: 0.413175\n",
      "Train Epoch: 1 [21000/60000 (35%)]\tLoss: 0.407277\n",
      "Train Epoch: 1 [22000/60000 (37%)]\tLoss: 0.356355\n",
      "Train Epoch: 1 [23000/60000 (38%)]\tLoss: 0.685514\n",
      "Train Epoch: 1 [24000/60000 (40%)]\tLoss: 0.439642\n",
      "Train Epoch: 1 [25000/60000 (42%)]\tLoss: 0.399821\n",
      "Train Epoch: 1 [26000/60000 (43%)]\tLoss: 0.556694\n",
      "Train Epoch: 1 [27000/60000 (45%)]\tLoss: 0.326490\n",
      "Train Epoch: 1 [28000/60000 (47%)]\tLoss: 0.401890\n",
      "Train Epoch: 1 [29000/60000 (48%)]\tLoss: 0.449772\n",
      "Train Epoch: 1 [30000/60000 (50%)]\tLoss: 0.427323\n",
      "Train Epoch: 1 [31000/60000 (52%)]\tLoss: 0.400165\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.334357\n",
      "Train Epoch: 1 [33000/60000 (55%)]\tLoss: 0.298294\n",
      "Train Epoch: 1 [34000/60000 (57%)]\tLoss: 0.551655\n",
      "Train Epoch: 1 [35000/60000 (58%)]\tLoss: 0.456408\n",
      "Train Epoch: 1 [36000/60000 (60%)]\tLoss: 0.409366\n",
      "Train Epoch: 1 [37000/60000 (62%)]\tLoss: 0.449682\n",
      "Train Epoch: 1 [38000/60000 (63%)]\tLoss: 0.433943\n",
      "Train Epoch: 1 [39000/60000 (65%)]\tLoss: 0.403032\n",
      "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.427653\n",
      "Train Epoch: 1 [41000/60000 (68%)]\tLoss: 0.596136\n",
      "Train Epoch: 1 [42000/60000 (70%)]\tLoss: 0.451446\n",
      "Train Epoch: 1 [43000/60000 (72%)]\tLoss: 0.415507\n",
      "Train Epoch: 1 [44000/60000 (73%)]\tLoss: 0.431509\n",
      "Train Epoch: 1 [45000/60000 (75%)]\tLoss: 0.269870\n",
      "Train Epoch: 1 [46000/60000 (77%)]\tLoss: 0.603037\n",
      "Train Epoch: 1 [47000/60000 (78%)]\tLoss: 0.361142\n",
      "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.533713\n",
      "Train Epoch: 1 [49000/60000 (82%)]\tLoss: 0.334492\n",
      "Train Epoch: 1 [50000/60000 (83%)]\tLoss: 0.360594\n",
      "Train Epoch: 1 [51000/60000 (85%)]\tLoss: 0.479819\n",
      "Train Epoch: 1 [52000/60000 (87%)]\tLoss: 0.408797\n",
      "Train Epoch: 1 [53000/60000 (88%)]\tLoss: 0.402639\n",
      "Train Epoch: 1 [54000/60000 (90%)]\tLoss: 0.291429\n",
      "Train Epoch: 1 [55000/60000 (92%)]\tLoss: 0.308326\n",
      "Train Epoch: 1 [56000/60000 (93%)]\tLoss: 0.503240\n",
      "Train Epoch: 1 [57000/60000 (95%)]\tLoss: 0.449824\n",
      "Train Epoch: 1 [58000/60000 (97%)]\tLoss: 0.339513\n",
      "Train Epoch: 1 [59000/60000 (98%)]\tLoss: 0.260989\n",
      "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.379833\n",
      "Train Epoch: 2 [1000/60000 (2%)]\tLoss: 0.421474\n",
      "Train Epoch: 2 [2000/60000 (3%)]\tLoss: 0.298147\n",
      "Train Epoch: 2 [3000/60000 (5%)]\tLoss: 0.302185\n",
      "Train Epoch: 2 [4000/60000 (7%)]\tLoss: 0.407515\n",
      "Train Epoch: 2 [5000/60000 (8%)]\tLoss: 0.372201\n",
      "Train Epoch: 2 [6000/60000 (10%)]\tLoss: 0.403835\n",
      "Train Epoch: 2 [7000/60000 (12%)]\tLoss: 0.403082\n",
      "Train Epoch: 2 [8000/60000 (13%)]\tLoss: 0.429058\n",
      "Train Epoch: 2 [9000/60000 (15%)]\tLoss: 0.427616\n",
      "Train Epoch: 2 [10000/60000 (17%)]\tLoss: 0.353761\n",
      "Train Epoch: 2 [11000/60000 (18%)]\tLoss: 0.392035\n",
      "Train Epoch: 2 [12000/60000 (20%)]\tLoss: 0.315083\n",
      "Train Epoch: 2 [13000/60000 (22%)]\tLoss: 0.303653\n",
      "Train Epoch: 2 [14000/60000 (23%)]\tLoss: 0.477725\n",
      "Train Epoch: 2 [15000/60000 (25%)]\tLoss: 0.311283\n",
      "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.320934\n",
      "Train Epoch: 2 [17000/60000 (28%)]\tLoss: 0.344658\n",
      "Train Epoch: 2 [18000/60000 (30%)]\tLoss: 0.252863\n",
      "Train Epoch: 2 [19000/60000 (32%)]\tLoss: 0.433338\n",
      "Train Epoch: 2 [20000/60000 (33%)]\tLoss: 0.401821\n",
      "Train Epoch: 2 [21000/60000 (35%)]\tLoss: 0.277403\n",
      "Train Epoch: 2 [22000/60000 (37%)]\tLoss: 0.324770\n",
      "Train Epoch: 2 [23000/60000 (38%)]\tLoss: 0.450259\n",
      "Train Epoch: 2 [24000/60000 (40%)]\tLoss: 0.280156\n",
      "Train Epoch: 2 [25000/60000 (42%)]\tLoss: 0.273340\n",
      "Train Epoch: 2 [26000/60000 (43%)]\tLoss: 0.405919\n",
      "Train Epoch: 2 [27000/60000 (45%)]\tLoss: 0.270347\n",
      "Train Epoch: 2 [28000/60000 (47%)]\tLoss: 0.389775\n",
      "Train Epoch: 2 [29000/60000 (48%)]\tLoss: 0.266520\n",
      "Train Epoch: 2 [30000/60000 (50%)]\tLoss: 0.369550\n",
      "Train Epoch: 2 [31000/60000 (52%)]\tLoss: 0.307309\n",
      "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.282144\n",
      "Train Epoch: 2 [33000/60000 (55%)]\tLoss: 0.259714\n",
      "Train Epoch: 2 [34000/60000 (57%)]\tLoss: 0.455187\n",
      "Train Epoch: 2 [35000/60000 (58%)]\tLoss: 0.418436\n",
      "Train Epoch: 2 [36000/60000 (60%)]\tLoss: 0.347754\n",
      "Train Epoch: 2 [37000/60000 (62%)]\tLoss: 0.285930\n",
      "Train Epoch: 2 [38000/60000 (63%)]\tLoss: 0.298169\n",
      "Train Epoch: 2 [39000/60000 (65%)]\tLoss: 0.315216\n",
      "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.207120\n",
      "Train Epoch: 2 [41000/60000 (68%)]\tLoss: 0.447531\n",
      "Train Epoch: 2 [42000/60000 (70%)]\tLoss: 0.356669\n",
      "Train Epoch: 2 [43000/60000 (72%)]\tLoss: 0.295995\n",
      "Train Epoch: 2 [44000/60000 (73%)]\tLoss: 0.306820\n",
      "Train Epoch: 2 [45000/60000 (75%)]\tLoss: 0.209821\n",
      "Train Epoch: 2 [46000/60000 (77%)]\tLoss: 0.474326\n",
      "Train Epoch: 2 [47000/60000 (78%)]\tLoss: 0.292024\n",
      "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.554571\n",
      "Train Epoch: 2 [49000/60000 (82%)]\tLoss: 0.313117\n",
      "Train Epoch: 2 [50000/60000 (83%)]\tLoss: 0.247606\n",
      "Train Epoch: 2 [51000/60000 (85%)]\tLoss: 0.385026\n",
      "Train Epoch: 2 [52000/60000 (87%)]\tLoss: 0.369550\n",
      "Train Epoch: 2 [53000/60000 (88%)]\tLoss: 0.326454\n",
      "Train Epoch: 2 [54000/60000 (90%)]\tLoss: 0.317689\n",
      "Train Epoch: 2 [55000/60000 (92%)]\tLoss: 0.296784\n",
      "Train Epoch: 2 [56000/60000 (93%)]\tLoss: 0.399627\n",
      "Train Epoch: 2 [57000/60000 (95%)]\tLoss: 0.408119\n",
      "Train Epoch: 2 [58000/60000 (97%)]\tLoss: 0.341115\n",
      "Train Epoch: 2 [59000/60000 (98%)]\tLoss: 0.191851\n",
      "Training is finished\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Set the Torch Distributed env variables so the training function can be run locally in the Notebook.\n",
    "# See https://pytorch.org/docs/stable/elastic/run.html#environment-variables\n",
    "os.environ[\"RANK\"] = \"0\"\n",
    "os.environ[\"LOCAL_RANK\"] = \"0\"\n",
    "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
    "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
    "os.environ[\"MASTER_PORT\"] = \"1234\"\n",
    "\n",
    "# Run the training function locally.\n",
    "train_fashion_mnist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scale PyTorch DDP with Kubeflow TrainJob\n",
    "\n",
    "You can use `TrainerClient()` from the Kubeflow SDK to communicate with Kubeflow Trainer APIs and scale your training function across multiple PyTorch training nodes.\n",
    "\n",
    "`TrainerClient()` verifies that you have required access to the Kubernetes cluster.\n",
    "\n",
    "Kubeflow Trainer creates a `TrainJob` resource and automatically sets the appropriate environment variables to set up PyTorch in distributed environment.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:49.832393Z",
     "iopub.status.busy": "2025-09-03T13:19:49.832117Z",
     "iopub.status.idle": "2025-09-03T13:19:51.924613Z",
     "shell.execute_reply": "2025-09-03T13:19:51.924264Z",
     "shell.execute_reply.started": "2025-09-03T13:19:49.832371Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from kubeflow.trainer import CustomTrainer, TrainerClient\n",
    "\n",
    "client = TrainerClient()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## List the Training Runtimes\n",
    "\n",
    "You can get the list of available Training Runtimes to start your TrainJob.\n",
    "\n",
    "Additionally, it might show available accelerator type and number of available resources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:52.867066Z",
     "iopub.status.busy": "2025-09-03T13:19:52.866470Z",
     "iopub.status.idle": "2025-09-03T13:19:53.794000Z",
     "shell.execute_reply": "2025-09-03T13:19:53.792956Z",
     "shell.execute_reply.started": "2025-09-03T13:19:52.867027Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Runtime(name='deepspeed-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='deepspeed', num_nodes=1, device='Unknown', device_count='1'), pretrained_model=None)\n",
      "Runtime(name='mlx-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='mlx', num_nodes=1, device='gpu', device_count='1'), pretrained_model=None)\n",
      "Runtime(name='torch-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)\n",
      "Runtime(name='torchtune-llama3.2-1b', trainer=RuntimeTrainer(trainer_type=<TrainerType.BUILTIN_TRAINER: 'BuiltinTrainer'>, framework='torchtune', num_nodes=1, device='gpu', device_count='2'), pretrained_model=None)\n",
      "Runtime(name='torchtune-llama3.2-3b', trainer=RuntimeTrainer(trainer_type=<TrainerType.BUILTIN_TRAINER: 'BuiltinTrainer'>, framework='torchtune', num_nodes=1, device='gpu', device_count='2.0'), pretrained_model=None)\n"
     ]
    }
   ],
   "source": [
    "for runtime in client.list_runtimes():\n",
    "    print(runtime)\n",
    "    if runtime.name == \"torch-distributed\":\n",
    "        torch_runtime = runtime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Distributed TrainJob\n",
    "\n",
    "Kubeflow TrainJob will train the above model on 3 PyTorch nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:56.525591Z",
     "iopub.status.busy": "2025-09-03T13:19:56.524936Z",
     "iopub.status.idle": "2025-09-03T13:19:56.721404Z",
     "shell.execute_reply": "2025-09-03T13:19:56.720565Z",
     "shell.execute_reply.started": "2025-09-03T13:19:56.525536Z"
    }
   },
   "outputs": [],
   "source": [
    "job_name = client.train(\n",
    "    trainer=CustomTrainer(\n",
    "        func=train_fashion_mnist,\n",
    "        # Set how many PyTorch nodes you want to use for distributed training.\n",
    "        num_nodes=3,\n",
    "        # Set the resources for each PyTorch node.\n",
    "        resources_per_node={\n",
    "            \"cpu\": 3,\n",
    "            \"memory\": \"16Gi\",\n",
    "            # Uncomment this to distribute the TrainJob using GPU nodes.\n",
    "            # \"nvidia.com/gpu\": 1,\n",
    "        },\n",
    "    ),\n",
    "    runtime=torch_runtime,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the TrainJob steps\n",
    "\n",
    "You can check the components of TrainJob that's created.\n",
    "\n",
    "Since the TrainJob performs distributed training across 3 nodes, it generates 3 steps: `trainer-node-0` .. `trainer-node-2`.\n",
    "\n",
    "You can get the individual status for each of these steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:01.378158Z",
     "iopub.status.busy": "2025-09-03T13:20:01.377707Z",
     "iopub.status.idle": "2025-09-03T13:20:12.713960Z",
     "shell.execute_reply": "2025-09-03T13:20:12.713295Z",
     "shell.execute_reply.started": "2025-09-03T13:20:01.378130Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TrainJob(name='j819fe212f8e', creation_timestamp=datetime.datetime(2025, 9, 3, 13, 19, 56, tzinfo=TzInfo(UTC)), runtime=Runtime(name='torch-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None), steps=[Step(name='node-0', status='Running', pod_name='j819fe212f8e-node-0-0-nlcqt', device='gpu', device_count='1'), Step(name='node-1', status='Running', pod_name='j819fe212f8e-node-0-1-rjztc', device='gpu', device_count='1'), Step(name='node-2', status='Running', pod_name='j819fe212f8e-node-0-2-l575c', device='gpu', device_count='1')], num_nodes=3, status='Running')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wait for the running status.\n",
    "client.wait_for_job_status(name=job_name, status={\"Running\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:24.045774Z",
     "iopub.status.busy": "2025-09-03T13:20:24.045480Z",
     "iopub.status.idle": "2025-09-03T13:20:24.772877Z",
     "shell.execute_reply": "2025-09-03T13:20:24.772178Z",
     "shell.execute_reply.started": "2025-09-03T13:20:24.045755Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step: node-0, Status: Running, Devices: gpu x 1\n",
      "\n",
      "Step: node-1, Status: Running, Devices: gpu x 1\n",
      "\n",
      "Step: node-2, Status: Running, Devices: gpu x 1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for c in client.get_job(name=job_name).steps:\n",
    "    print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Watch the TrainJob logs\n",
    "\n",
    "We can use the `get_job_logs()` API to get the TrainJob logs.\n",
    "\n",
    "Since we run training on 3 GPUs, every PyTorch node uses 60,000/3 = 20,000 images from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:26.729486Z",
     "iopub.status.busy": "2025-09-03T13:20:26.728951Z",
     "iopub.status.idle": "2025-09-03T13:20:29.596510Z",
     "shell.execute_reply": "2025-09-03T13:20:29.594741Z",
     "shell.execute_reply.started": "2025-09-03T13:20:26.729446Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[W903 13:20:06.250970596 socket.cpp:755] [c10d] The IPv6 network addresses of (j819fe212f8e-node-0-0.j819fe212f8e, 29500) cannot be retrieved (gai error: -2 - Name or service not known).\n",
      "[W903 13:20:07.936850072 socket.cpp:755] [c10d] The IPv6 network addresses of (j819fe212f8e-node-0-0.j819fe212f8e, 29500) cannot be retrieved (gai error: -2 - Name or service not known).\n",
      "[W903 13:20:08.873279869 socket.cpp:755] [c10d] The IPv6 network addresses of (j819fe212f8e-node-0-0.j819fe212f8e, 29500) cannot be retrieved (gai error: -2 - Name or service not known).\n",
      "[W903 13:20:08.673190939 socket.cpp:755] [c10d] The IPv6 network addresses of (j819fe212f8e-node-0-0.j819fe212f8e, 29500) cannot be retrieved (gai error: -2 - Name or service not known).\n",
      "[W903 13:20:09.683786251 socket.cpp:755] [c10d] The IPv6 network addresses of (j819fe212f8e-node-0-0.j819fe212f8e, 29500) cannot be retrieved (gai error: -2 - Name or service not known).\n",
      "Using Device: cuda, Backend: nccl\n",
      "Distributed Training for WORLD_SIZE: 3, RANK: 0, LOCAL_RANK: 0\n",
      "100%|██████████| 26.4M/26.4M [00:02<00:00, 12.2MB/s]\n",
      "100%|██████████| 29.5k/29.5k [00:00<00:00, 213kB/s]\n",
      "100%|██████████| 4.42M/4.42M [00:01<00:00, 3.85MB/s]\n",
      "100%|██████████| 5.15k/5.15k [00:00<00:00, 43.7MB/s]\n",
      "/opt/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:4631: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. \n",
      "  warnings.warn(  # warn only once\n",
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.297219\n",
      "Train Epoch: 1 [1000/60000 (5%)]\tLoss: 1.803301\n",
      "Train Epoch: 1 [2000/60000 (10%)]\tLoss: 1.663171\n",
      "Train Epoch: 1 [3000/60000 (15%)]\tLoss: 1.462810\n",
      "Train Epoch: 1 [4000/60000 (20%)]\tLoss: 1.270390\n",
      "Train Epoch: 1 [5000/60000 (25%)]\tLoss: 0.977784\n",
      "Train Epoch: 1 [6000/60000 (30%)]\tLoss: 0.791962\n",
      "Train Epoch: 1 [7000/60000 (35%)]\tLoss: 0.626636\n",
      "Train Epoch: 1 [8000/60000 (40%)]\tLoss: 0.613735\n",
      "Train Epoch: 1 [9000/60000 (45%)]\tLoss: 0.590524\n",
      "Train Epoch: 1 [10000/60000 (50%)]\tLoss: 0.529664\n",
      "Train Epoch: 1 [11000/60000 (55%)]\tLoss: 0.448692\n",
      "Train Epoch: 1 [12000/60000 (60%)]\tLoss: 0.508065\n",
      "Train Epoch: 1 [13000/60000 (65%)]\tLoss: 0.519867\n",
      "Train Epoch: 1 [14000/60000 (70%)]\tLoss: 0.485550\n",
      "Train Epoch: 1 [15000/60000 (75%)]\tLoss: 0.377857\n",
      "Train Epoch: 1 [16000/60000 (80%)]\tLoss: 0.456072\n",
      "Train Epoch: 1 [17000/60000 (85%)]\tLoss: 0.563855\n",
      "Train Epoch: 1 [18000/60000 (90%)]\tLoss: 0.426923\n",
      "Train Epoch: 1 [19000/60000 (95%)]\tLoss: 0.379981\n",
      "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.416116\n",
      "Train Epoch: 2 [1000/60000 (5%)]\tLoss: 0.506344\n",
      "Train Epoch: 2 [2000/60000 (10%)]\tLoss: 0.510181\n",
      "Train Epoch: 2 [3000/60000 (15%)]\tLoss: 0.388432\n",
      "Train Epoch: 2 [4000/60000 (20%)]\tLoss: 0.361425\n",
      "Train Epoch: 2 [5000/60000 (25%)]\tLoss: 0.428971\n",
      "Train Epoch: 2 [6000/60000 (30%)]\tLoss: 0.409087\n",
      "Train Epoch: 2 [7000/60000 (35%)]\tLoss: 0.423781\n",
      "Train Epoch: 2 [8000/60000 (40%)]\tLoss: 0.530143\n",
      "Train Epoch: 2 [9000/60000 (45%)]\tLoss: 0.390188\n",
      "Train Epoch: 2 [10000/60000 (50%)]\tLoss: 0.359034\n",
      "Train Epoch: 2 [11000/60000 (55%)]\tLoss: 0.353203\n",
      "Train Epoch: 2 [12000/60000 (60%)]\tLoss: 0.394956\n",
      "Train Epoch: 2 [13000/60000 (65%)]\tLoss: 0.407927\n",
      "Train Epoch: 2 [14000/60000 (70%)]\tLoss: 0.296730\n",
      "Train Epoch: 2 [15000/60000 (75%)]\tLoss: 0.318915\n",
      "Train Epoch: 2 [16000/60000 (80%)]\tLoss: 0.317562\n",
      "Train Epoch: 2 [17000/60000 (85%)]\tLoss: 0.396108\n",
      "Train Epoch: 2 [18000/60000 (90%)]\tLoss: 0.302800\n",
      "Train Epoch: 2 [19000/60000 (95%)]\tLoss: 0.364170\n",
      "Training is finished\n"
     ]
    }
   ],
   "source": [
    "for logline in client.get_job_logs(job_name, follow=True):\n",
    "    print(logline)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Delete the TrainJob\n",
    "\n",
    "When TrainJob is finished, you can delete the resource.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.delete_job(job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
